syntax = "proto3";

package jax;

// Kind of a single pytree node; stored in PyTreeNodeDefProto.type.
//
// NOTE(review): the LIST/NONE/TUPLE/DICT values presumably mirror the Python
// built-in container types of the same name -- confirm against the
// serializer.
enum PyTreeNodeType {
  // Zero value, i.e. what a reader sees when the field was never set.
  PY_TREE_KIND_INVALID = 0;
  // Leaf node.
  PY_TREE_KIND_LEAF = 1;
  // List node.
  PY_TREE_KIND_LIST = 2;
  // None node.
  PY_TREE_KIND_NONE = 3;
  // Tuple node.
  PY_TREE_KIND_TUPLE = 4;
  // Dict node; the only kind for which PyTreeNodeDefProto.dict_keys is set.
  PY_TREE_KIND_DICT = 5;
}

// Keys of a dict node, stored as integer string ids rather than inline
// strings.
message DictKeysProto {
  // One id per dict key.
  // NOTE(review): presumably each id is an index into
  // PyTreeDefProto.interned_strings -- TODO confirm against the
  // reader/writer code.
  repeated uint32 str_id = 1;
}

// Definition of a single node of a pytree.
message PyTreeNodeDefProto {
  // Arity of this node. Used to recover the tree structure from the flat
  // `nodes` list in PyTreeDefProto.
  // NOTE(review): the traversal order of that list is not specified in this
  // file -- confirm against the serializer.
  uint32 arity = 1;
  // Node type.
  PyTreeNodeType type = 2;
  // Only set when type == PY_TREE_KIND_DICT.
  DictKeysProto dict_keys = 3;
}

// Serialized definition of a pytree: a flat list of node definitions plus a
// table of strings.
message PyTreeDefProto {
  // Node definitions that make up the tree. Each node's `arity` is what lets
  // a reader rebuild the tree shape from this flat list.
  repeated PyTreeNodeDefProto nodes = 1;
  // String table for this tree.
  // NOTE(review): presumably the target of the ids in DictKeysProto.str_id
  // -- TODO confirm.
  repeated string interned_strings = 2;
}
